from typing import Tuple

import time
import torch
import random
import tianshou

import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box

from pathlib import Path
from omegaconf import DictConfig, OmegaConf
from torch.utils.tensorboard import SummaryWriter
import sys, os


from State.buffer import VectorGCReplayBufferManager
from tianshou.data import ReplayBuffer
from Utils.logger import GCLogger
from Utils.flatten_dict_observation_wrapper import FlattenDictObservation, FlattenFactorObservation
from Initializers.flat_obs_normalize import FlatNormalization
from Utils.gc_env_wrapper import GCVectorEnv

REPO_PATH = Path(__file__).resolve().parents[1]


class AttrDict(dict):
    """A ``dict`` whose entries can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as ``dict`` and expose items as attributes."""
        super().__init__(*args, **kwargs)
        # Use the mapping itself as the instance namespace, so ``d.key`` and
        # ``d["key"]`` are backed by the same storage.
        self.__dict__ = self


def clean_dict(_dict):
    """Normalize a (possibly nested) config mapping and wrap it in ``AttrDict``.

    Every value equal to the empty string is replaced by ``None`` and every
    value that is a ``dict`` instance is cleaned recursively. ``_dict`` is
    updated in place; the return value is a new ``AttrDict`` built from it.
    """
    for key, value in _dict.items():
        if value == "":
            # encode empty string as None
            _dict[key] = None
        elif isinstance(value, dict):
            _dict[key] = clean_dict(value)
        else:
            # Written back unchanged, mirroring the unconditional assignment.
            _dict[key] = value
    return AttrDict(_dict)


def init_logistics(config: DictConfig, wdb_run=None, saving: bool = True, training: bool = True,) \
        -> Tuple[gym.Env, tianshou.env.BaseVectorEnv, tianshou.env.BaseVectorEnv, FlatNormalization, GCLogger, AttrDict]:
    """Set up printing, output directories, logger, environments and seeds for a run.

    Args:
        config: experiment configuration; ``config.exp_path`` is reset to ``None``
            here and then filled in by ``init_saving`` when that is called.
        wdb_run: optional run handle forwarded to ``init_logger``.
        saving: when true, output directories and a logger are created.
        training: when true, the training vector env is built and seeded.

    Returns:
        ``(single_env, train_env, test_env, norm, logger, config)`` where
        ``train_env`` is ``None`` if ``training`` is false, ``logger`` is ``None``
        if ``saving`` is false, and ``config`` is the ``AttrDict`` produced by
        ``clean_dict`` with environment-derived fields added.
    """
    init_printing_format()
    torch.set_default_dtype(torch.float32)

    logger = None
    save_gif = config.save.save_gif_num
    config.exp_path = None
    # Directories are needed either for regular saving or just for gifs.
    if saving or save_gif:
        init_saving(config)
    if saving:
        # The logger logs to tensorboard, wandb if not None, and prints out to the log file
        logger = init_logger(config, wdb_run=wdb_run)

    # initializes a parallel environment for the train_env, see get_single_env to add environments
    train_env = init_env(config, config.env.num_train_envs) if training else None
    # test env renders to "rgb_array" when gifs are requested, otherwise "human"
    test_env = init_env(config, config.env.num_test_envs, render_mode="rgb_array" if save_gif else "human")
    # creates an extra environment to access environment parameters, though this is a bit redundant
    single_env = get_single_env(config)

    if training:
        init_seed(train_env, config.seed)
    # init_seed also reseeds the global torch/numpy/random generators, so after
    # this call their state is derived from config.seed + 10000.
    init_seed(test_env, config.seed + 10000)

    # Adds in environment specific parameters into the config so they can be passed around more easily
    config = clean_dict(config)
    config.num_factors = single_env.num_factors
    config.goal_based = config.env.goal_based = single_env.goal_based
    config.factor_spaces = single_env.factor_spaces
    obs_space = single_env.observation_space
    assert isinstance(obs_space, Box)
    assert len(obs_space.shape) == 1
    # assumes one dimensional observations #TODO: image observations
    config.obs_size = single_env.observation_space.shape[0]
    norm = init_norm(single_env)

    return single_env, train_env, test_env, norm, logger, config


def init_saving(config: DictConfig) -> None:
    """Create the experiment output directories and record them on ``config``.

    The experiment directory is
    ``<save_dir>/results/<config.sub_dirname>/<info>_<timestamp>``, where
    ``save_dir`` is ``config.alt_path`` when that is set (non-empty) and
    ``REPO_PATH`` otherwise, and ``info`` is ``config.info`` with spaces
    replaced by underscores. A ``gifs`` subdirectory is created when
    ``config.save.save_gif_num`` is truthy, and a separate replay-buffer
    directory when ``config.save.save_replay_buffer`` is truthy.

    Side effects on ``config``: sets ``exp_path`` and ``replay_buffer_dir``
    (``None`` unless the replay buffer is saved). ``config.yaml`` is written
    before these assignments, so the dumped file does not contain the new paths.

    Raises:
        FileExistsError: if a target directory already exists (``mkdir`` is
            called without ``exist_ok``).
    """
    # Truthiness check so that both "" and None fall back to the repo path
    # (len(None) would raise TypeError).
    save_dir = Path(config.alt_path) if config.alt_path else REPO_PATH

    info = config.info.replace(" ", "_")
    experiment_dirname = info + "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
    exp_path = save_dir / "results" / config.sub_dirname / experiment_dirname
    exp_path.mkdir(parents=True)

    gif_path = exp_path / "gifs"
    if config.save.save_gif_num:
        gif_path.mkdir(parents=True)

    # Hand the open handle to OmegaConf instead of its name, so the file is
    # opened exactly once; OmegaConf.save accepts a filename or a file object.
    with open(exp_path / "config.yaml", "w", encoding="utf-8") as fp:
        OmegaConf.save(config=config, f=fp)

    config.exp_path = exp_path
    config.replay_buffer_dir = None

    if config.save.save_replay_buffer:
        config.replay_buffer_dir = save_dir / "replay_buffer" / config.sub_dirname / experiment_dirname
        config.replay_buffer_dir.mkdir(parents=True)


def init_logger(config: DictConfig, wdb_run=None) -> GCLogger:
    """Build a ``GCLogger`` backed by a TensorBoard writer at ``config.exp_path``.

    ``wdb_run`` is forwarded to the logger unchanged; the logging intervals
    are fixed here and ``log_learn_her_keys`` is read from ``config.log``.
    """
    tensorboard_writer = SummaryWriter(config.exp_path)
    return GCLogger(
        tensorboard_writer,
        wdb_run=wdb_run,
        train_interval=1,
        test_interval=1,
        update_interval=10,
        policy_interval=1,
        log_learn_her_keys=config.log.log_learn_her_keys,
    )


def _load_state_dict_if_present(module, relative_path, loaded_message, device):
    """Load ``module``'s state dict from ``REPO_PATH / relative_path`` if the file exists.

    Prints ``loaded_message`` followed by the path before loading, or a
    warning when the configured file is missing (the load is then skipped).
    """
    checkpoint_path = REPO_PATH / relative_path
    # is_file() is already False for a non-existent path, so no exists() check.
    if not checkpoint_path.is_file():
        print("WARNING: checkpoint not found, skipping load", checkpoint_path)
        return
    print(loaded_message, checkpoint_path)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    module.load_state_dict(torch.load(checkpoint_path, map_location=device))


def init_loading(
        config: AttrDict,
        dynamics,
        graph_encoding,
        policy,
        buffer,
):
    """Restore model weights and optionally the replay buffer from disk.

    Paths come from ``config.load`` and are resolved relative to ``REPO_PATH``.
    An entry that is ``None`` is ignored. An entry pointing at a file that does
    not exist is skipped with a printed warning rather than an error.

    Args:
        config: run configuration; ``config.load`` and ``config.device`` are read.
        dynamics: module receiving ``config.load.load_dynamics``.
        graph_encoding: module receiving ``config.load.load_graph_encoding``;
            skipped when ``None``.
        policy: module receiving ``config.load.load_policy``.
        buffer: current replay buffer, or ``None`` to skip buffer loading.

    Returns:
        The buffer loaded from ``config.load.load_replay_buffer`` when that is
        set, ``buffer`` is not ``None``, ``config.load.load_rpb`` is truthy and
        the file exists; otherwise the ``buffer`` argument unchanged.
    """
    load_config = config.load
    if load_config.load_dynamics is not None:  # TODO: create alternate load path
        _load_state_dict_if_present(
            dynamics, load_config.load_dynamics, "dynamics loaded", config.device)

    if graph_encoding is not None and load_config.load_graph_encoding is not None:
        _load_state_dict_if_present(
            graph_encoding, load_config.load_graph_encoding, "graph encoding loaded", config.device)

    if load_config.load_policy is not None:
        _load_state_dict_if_present(
            policy, load_config.load_policy, "policy loaded", config.device)

    if load_config.load_replay_buffer is not None and buffer is not None and load_config.load_rpb:
        replay_buffer_path = REPO_PATH / load_config.load_replay_buffer
        if replay_buffer_path.is_file():
            buffer = VectorGCReplayBufferManager.load_hdf5(replay_buffer_path)
            print("lower replay buffer loaded", replay_buffer_path)
        else:
            print("WARNING: replay buffer not found, skipping load", replay_buffer_path)

    return buffer


def init_seed(env: tianshou.env.BaseVectorEnv, seed: int = 0) -> None:
    """Seed PyTorch, NumPy, Python's ``random`` and ``env`` with the same value.

    CUDA generators are seeded too when CUDA is available. Note that this
    resets the process-wide generators, not only the environment's.
    """
    global_seeders = [torch.manual_seed]
    if torch.cuda.is_available():
        global_seeders.append(torch.cuda.manual_seed_all)
    global_seeders.extend((np.random.seed, random.seed))
    for apply_seed in global_seeders:
        apply_seed(seed)
    # TODO: check if env has seed()
    env.seed(seed)


def init_printing_format() -> None:
    """Configure the global NumPy and PyTorch print options used for logging."""
    numpy_options = {
        "threshold": 3000,
        "linewidth": 120,
        "precision": 4,
        "suppress": True,
    }
    torch_options = {"precision": 4, "sci_mode": False}
    np.set_printoptions(**numpy_options)
    torch.set_printoptions(**torch_options)


def to_device(dictionary, device):
    """
    Return a new dict with every tensor in ``dictionary`` moved to ``device``.

    Nested dicts are handled recursively; the input dict itself is not
    modified. Any value that is neither a dict nor a ``torch.Tensor`` raises
    ``ValueError``.
    """
    def _relocate(name, item):
        # Dicts recurse, tensors move, anything else is rejected.
        if isinstance(item, dict):
            return to_device(item, device)
        if isinstance(item, torch.Tensor):
            return item.to(device)
        raise ValueError("Unknown value type {} for key {}".format(type(item), name))

    return {name: _relocate(name, item) for name, item in dictionary.items()}

def get_single_env(config: DictConfig, render_mode='human', args=None, i=0) -> gym.Env:
    """Build one (non-vectorized) environment selected by ``config.env.env_name``.

    Branches, in order:
      * ``env_name`` is a key of ``config.env.mini_behavior``: a
        ``MiniGrid-<env_name>-v0`` gym environment wrapped in
        ``FlattenDictObservation``; ``render_mode`` is applied here only.
      * ``"test_environment"``: a ``TestEnv`` instance.
      * ``"igibson"``: a headless ``iGibsonFactorObsEnv`` wrapped in
        ``FlattenDictObservation``.
      * anything else: an ac_infer environment built from
        ``config.env.ac_env.ac_config_path``, converted with ``convert_env_rl``
        and wrapped in ``FlattenFactorObservation``.

    Args:
        config: experiment configuration.
        render_mode: render mode passed to ``set_render_mode`` in the
            mini-behavior branch; not referenced by the other branches.
        args: not referenced anywhere in this function.
        i: offset added to ``config.seed`` for the ac_infer environment seed.

    Returns:
        The constructed (and, except for ``TestEnv``, wrapped) environment.
    """
    env_config = config.env
    env_name = env_config.env_name
    # Both sub-configs are read up front, regardless of which branch is taken.
    mini_behavior_config = env_config.mini_behavior
    igibson_config = env_config.igibson
    if env_name in mini_behavior_config:
        env_specific_config = mini_behavior_config[env_name]
        env_id = "MiniGrid-" + env_name + "-v0"
        kwargs = {"evaluate_graph": env_config.evaluate_graph,
                  "discrete_obs": mini_behavior_config.discrete_obs,
                  "room_size": env_specific_config.room_size,
                  "max_steps": env_specific_config.max_steps,
                  "use_stage_reward": env_specific_config.use_stage_reward,
                  "random_obj_pose": env_specific_config.random_obj_pose,
                  }
        env = gym.make(env_id, **kwargs)
        env = FlattenDictObservation(env)
        env.set_render_mode(render_mode)
    elif env_name == "test_environment":
        # NOTE(review): TestEnv is not imported in the visible part of this
        # file -- confirm it is defined, otherwise this branch raises NameError.
        env = TestEnv()
    elif env_name == "igibson":
        # Convert to a plain container with interpolations resolved before
        # handing it to iGibson as its config.
        igibson_config = OmegaConf.to_container(igibson_config, resolve=True)
        # Imported lazily so igibson is only required when this branch is used.
        from igibson.envs.igibson_factor_obs_env import iGibsonFactorObsEnv
        env = iGibsonFactorObsEnv(
            config_file=igibson_config,
            mode="headless",
            action_timestep=1 / 10.0,
            physics_timestep=1 / 120.0,
        )
        env = FlattenDictObservation(env)
    else: # ac_env
        # assume name in ac_infer.Environment.initialize_environment
        # NOTE(review): the original note said "assumes args is not None", but
        # `args` is never used below -- that assumption looks stale.
        # TODO: make sure record either does not initialize, or is stored in a different way
        # Make <sys.path[0]>/Causal/ac_infer importable (added at most once).
        if os.path.join(sys.path[0],"Causal", "ac_infer") not in sys.path: sys.path.append(os.path.join(sys.path[0],"Causal", "ac_infer"))
        from Causal.ac_infer.Environment.Environments.initialize_environment import initialize_environment
        from Causal.ac_infer.Hyperparam.read_config import read_config
        from Causal.ac_infer.Environment.environment import convert_env_rl
        environment_config = read_config(config.env.ac_env.ac_config_path)
        # Offset the seed by the env index so each env can get its own seed.
        environment_config.environment.seed = config.seed + i
        environment_config.environment.render = config.env.render
        # `record` is returned but not used afterwards in this function.
        env, record = initialize_environment(environment_config.environment, environment_config.record)
        if env.goal_based: env.set_goal_params({"radius": config.policy.reward.target_goal_epsilon * (env.length + env.width) / 4}) # sets the radius for rendering

        env = convert_env_rl(env)
        env = FlattenFactorObservation(env)        

    return env

def init_norm(single_env):
    """Build the ``FlatNormalization`` for ``single_env``.

    It is constructed from the environment's ``all_names``, its
    ``object_range`` mapping, and that mapping's ``"Goal"`` entry.
    """
    names = single_env.all_names
    ranges = single_env.object_range
    goal_range = single_env.object_range["Goal"]
    return FlatNormalization(names, ranges, goal_range)

def init_env(config: DictConfig, num_envs: int, render_mode='human') -> tianshou.env.BaseVectorEnv:
    """Build a vectorized environment made of ``num_envs`` single environments.

    Each factory calls ``get_single_env(config, render_mode, i=<index>)`` with
    its own index.

    Args:
        config: experiment configuration; ``config.env.goal_sampler`` selects
            the vector env class when more than one env is requested.
        num_envs: number of environments to create.
        render_mode: forwarded to ``get_single_env``.

    Returns:
        A ``DummyVectorEnv`` when ``num_envs == 1`` (``goal_sampler`` is not
        consulted in that case); otherwise a ``GCVectorEnv`` if
        ``config.env.goal_sampler`` is truthy, else a ``SubprocVectorEnv``.
    """
    # if config.env.render:
    #     assert num_envs == 1

    # Bind ``i`` as a default argument: a plain ``lambda: ...i...`` would look
    # up ``i`` when called, after the comprehension has finished, so every
    # factory would receive the last index.
    env_fns = [lambda i=i: get_single_env(config, render_mode, i=i) for i in range(num_envs)]
    if num_envs == 1:
        return tianshou.env.DummyVectorEnv(env_fns)
    if config.env.goal_sampler:
        return GCVectorEnv(env_fns)
    return tianshou.env.SubprocVectorEnv(env_fns)